1 Deep Learning Classification

1.1 Data Extraction

library(rvest)
url <- "https://wiki.socr.umich.edu/index.php/SOCR_Data_July2009_ID_NI"
web_data <- read_html(url)
df <- web_data %>% html_table(fill = TRUE) %>% .[[1]]
str(df)
## tibble [672 × 13] (S3: tbl_df/tbl/data.frame)
##  $ Subject_ID: int [1:672] 1 1 1 1 1 1 1 1 1 1 ...
##  $ Group     : chr [1:672] "AD" "AD" "AD" "AD" ...
##  $ MMSE      : int [1:672] 21 21 21 21 21 21 21 21 21 21 ...
##  $ CDR       : num [1:672] 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 ...
##  $ Sex       : logi [1:672] FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ Age       : int [1:672] 82 82 82 82 82 82 82 82 82 82 ...
##  $ TBV       : int [1:672] 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 ...
##  $ GMV       : int [1:672] 522930 522930 522930 522930 522930 522930 522930 522930 522930 522930 ...
##  $ WMV       : int [1:672] 247583 247583 247583 247583 247583 247583 247583 247583 247583 247583 ...
##  $ CSFV      : int [1:672] 281194 281194 281194 281194 281194 281194 281194 281194 281194 281194 ...
##  $ ROI       : int [1:672] 1 1 1 1 2 2 2 2 3 3 ...
##  $ Measure   : chr [1:672] "SA" "SI" "CV" "FD" ...
##  $ Value     : num [1:672] 14704.29 0.43 0.08 2.15 9084.68 ...
head(df)
## # A tibble: 6 × 13
##   Subject_ID Group  MMSE   CDR Sex     Age     TBV    GMV    WMV   CSFV   ROI
##        <int> <chr> <int> <dbl> <lgl> <int>   <int>  <int>  <int>  <int> <int>
## 1          1 AD       21   1.2 FALSE    82 1051706 522930 247583 281194     1
## 2          1 AD       21   1.2 FALSE    82 1051706 522930 247583 281194     1
## 3          1 AD       21   1.2 FALSE    82 1051706 522930 247583 281194     1
## 4          1 AD       21   1.2 FALSE    82 1051706 522930 247583 281194     1
## 5          1 AD       21   1.2 FALSE    82 1051706 522930 247583 281194     2
## 6          1 AD       21   1.2 FALSE    82 1051706 522930 247583 281194     2
## # ℹ 2 more variables: Measure <chr>, Value <dbl>
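Because the wiki page's layout or availability can change between runs, it may be safer to cache the scraped table locally and re-read it afterwards. A minimal sketch (the file name `socr_ni_data.csv` is a hypothetical choice):

```r
# Cache the scraped table so later runs don't depend on the wiki being reachable
# (the file name is hypothetical)
if (!file.exists("socr_ni_data.csv")) {
  write.csv(df, "socr_ni_data.csv", row.names = FALSE)
}
df <- read.csv("socr_ni_data.csv", stringsAsFactors = FALSE)
```

Note that read.csv returns a plain data frame rather than a tibble, so column classes may differ slightly from the scraped version.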

1.2 Pool the Cohort

df$Group <- ifelse(df$Group %in% c("MCI", "AD"), "Patients", df$Group)

# Convert the remaining categorical columns, 'Measure' and 'Sex', to numeric codes
df$Measure <- as.numeric(as.factor(df$Measure))
df$Sex <- as.numeric(as.factor(df$Sex))

str(df)
## tibble [672 × 13] (S3: tbl_df/tbl/data.frame)
##  $ Subject_ID: int [1:672] 1 1 1 1 1 1 1 1 1 1 ...
##  $ Group     : chr [1:672] "Patients" "Patients" "Patients" "Patients" ...
##  $ MMSE      : int [1:672] 21 21 21 21 21 21 21 21 21 21 ...
##  $ CDR       : num [1:672] 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 1.2 ...
##  $ Sex       : num [1:672] 1 1 1 1 1 1 1 1 1 1 ...
##  $ Age       : int [1:672] 82 82 82 82 82 82 82 82 82 82 ...
##  $ TBV       : int [1:672] 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 1051706 ...
##  $ GMV       : int [1:672] 522930 522930 522930 522930 522930 522930 522930 522930 522930 522930 ...
##  $ WMV       : int [1:672] 247583 247583 247583 247583 247583 247583 247583 247583 247583 247583 ...
##  $ CSFV      : int [1:672] 281194 281194 281194 281194 281194 281194 281194 281194 281194 281194 ...
##  $ ROI       : int [1:672] 1 1 1 1 2 2 2 2 3 3 ...
##  $ Measure   : num [1:672] 3 4 1 2 3 4 1 2 3 4 ...
##  $ Value     : num [1:672] 14704.29 0.43 0.08 2.15 9084.68 ...
head(df)
## # A tibble: 6 × 13
##   Subject_ID Group     MMSE   CDR   Sex   Age     TBV    GMV    WMV   CSFV   ROI
##        <int> <chr>    <int> <dbl> <dbl> <int>   <int>  <int>  <int>  <int> <int>
## 1          1 Patients    21   1.2     1    82 1051706 522930 247583 281194     1
## 2          1 Patients    21   1.2     1    82 1051706 522930 247583 281194     1
## 3          1 Patients    21   1.2     1    82 1051706 522930 247583 281194     1
## 4          1 Patients    21   1.2     1    82 1051706 522930 247583 281194     1
## 5          1 Patients    21   1.2     1    82 1051706 522930 247583 281194     2
## 6          1 Patients    21   1.2     1    82 1051706 522930 247583 281194     2
## # ℹ 2 more variables: Measure <dbl>, Value <dbl>

This code pools the MCI and AD groups into a single "Patients" category, converts the Measure and Sex columns from categorical to numeric coding for analysis, and then displays the structure of the modified data frame.
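The same pooling and recoding logic can be checked on a small toy data frame (the toy values below are illustrative, not from the SOCR data):

```r
# Toy data frame illustrating the pooling and numeric recoding above
toy <- data.frame(Group = c("AD", "MCI", "NC"),
                  Measure = c("SA", "SI", "CV"))

toy$Group <- ifelse(toy$Group %in% c("MCI", "AD"), "Patients", toy$Group)
# Factor levels sort alphabetically, so CV -> 1, SA -> 2, SI -> 3
toy$Measure <- as.numeric(as.factor(toy$Measure))
toy  # Group becomes Patients, Patients, NC; Measure becomes 2, 3, 1
```

This also shows why the Measure codes in the output above run 3, 4, 1, 2: the alphabetical factor ordering, not the order of appearance, determines the codes.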

1.3 Multi-layer Perceptron Classifier

#install.packages("remotes")
#remotes::install_github("rstudio/tensorflow", force = T)
#install.packages("reticulate")
library(reticulate)
use_python("C:/Users/Jun/AppData/Local/Programs/Python/Python38/python.exe")
#py_install("Pillow", pip = TRUE)
#tensorflow::install_tensorflow()
library(tensorflow)
#tensorflow::install_tensorflow(extra_packages = "pillow")
#tensorflow::install_tensorflow(version = "2.13.*")
#virtualenv_create("r-tensorflow")
#use_virtualenv("r-tensorflow", required = TRUE)

#devtools::install_github("rstudio/keras")
library(keras)
#install_keras()

py_config()
## python:         C:/Users/Jun/AppData/Local/Programs/Python/Python38/python.exe
## libpython:      C:/Users/Jun/AppData/Local/Programs/Python/Python38/python38.dll
## pythonhome:     C:/Users/Jun/AppData/Local/Programs/Python/Python38
## version:        3.8.2 (tags/v3.8.2:7b3ab59, Feb 25 2020, 23:03:10) [MSC v.1916 64 bit (AMD64)]
## Architecture:   64bit
## numpy:          C:/Users/Jun/AppData/Local/Programs/Python/Python38/Lib/site-packages/numpy
## numpy_version:  1.24.3
## tensorflow:     C:\Users\Jun\AppData\Local\Programs\Python\Python38\lib\site-packages\tensorflow\__init__.p
## 
## NOTE: Python version was forced by use_python() function
set.seed(2024)  # for reproducibility

features <- df[, -c(1, 2)]   # drop Subject_ID and Group
target <- df[, 2]            # Group column
target <- ifelse(target == "Patients", 1, 0)  # Binary encoding for target

indices <- sample(1:nrow(features), size = 0.8 * nrow(features), replace = FALSE)
train_x <- features[indices, ]
train_y <- target[indices,]
test_x <- features[-indices, ]
test_y <- target[-indices,]
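Note that each Subject_ID contributes multiple rows (one per ROI-by-Measure combination), so the row-wise split above can place rows from the same subject in both the training and test sets, which may inflate test metrics. A subject-level split avoids that leakage; a minimal sketch, reusing the `df`, `features`, and `target` objects from above:

```r
set.seed(2024)
# Sample subjects, not rows, so no subject spans both sets
subj <- unique(df$Subject_ID)
train_subj <- sample(subj, size = floor(0.8 * length(subj)))
train_idx <- which(df$Subject_ID %in% train_subj)

train_x <- features[train_idx, ];  train_y <- target[train_idx, ]
test_x  <- features[-train_idx, ]; test_y  <- target[-train_idx, ]
```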

# Define the model
model <- keras_model_sequential() %>%
  layer_dense(units = 8, activation = 'relu', input_shape = c(ncol(train_x))) %>%
  layer_dense(units = 1, activation = 'sigmoid')

# Compile the model
model %>% compile(
  loss = 'binary_crossentropy',
  optimizer = 'adam',
  metrics = 'accuracy'
)
# Assuming `train_x` is a data frame, convert it to a matrix
train_x_matrix <- as.matrix(train_x)
train_y_matrix <- as.matrix(train_y)  

history <- model %>% fit(
  train_x_matrix,
  train_y_matrix,
  epochs = 15,
  batch_size = 5,
  validation_split = 0.2
)
## Epoch 1/15
## 86/86 - 1s - loss: 56296.5625 - accuracy: 0.4709 - val_loss: 667.0047 - val_accuracy: 0.7222 - 963ms/epoch - 11ms/step
## Epoch 2/15
## 86/86 - 0s - loss: 39.2112 - accuracy: 0.9371 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 238ms/epoch - 3ms/step
## Epoch 3/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 194ms/epoch - 2ms/step
## Epoch 4/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 206ms/epoch - 2ms/step
## Epoch 5/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 208ms/epoch - 2ms/step
## Epoch 6/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 209ms/epoch - 2ms/step
## Epoch 7/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 221ms/epoch - 3ms/step
## Epoch 8/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 198ms/epoch - 2ms/step
## Epoch 9/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 191ms/epoch - 2ms/step
## Epoch 10/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 184ms/epoch - 2ms/step
## Epoch 11/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 185ms/epoch - 2ms/step
## Epoch 12/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 200ms/epoch - 2ms/step
## Epoch 13/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 190ms/epoch - 2ms/step
## Epoch 14/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 192ms/epoch - 2ms/step
## Epoch 15/15
## 86/86 - 0s - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000 - 192ms/epoch - 2ms/step
# Predict on test set
predictions <- model %>% predict(as.matrix(test_x))
## 5/5 - 0s - 77ms/epoch - 15ms/step
predicted_classes <- ifelse(predictions > 0.5, 1, 0)

# Evaluation metrics
confusion_matrix <- table(Predicted = predicted_classes, Actual = as.matrix(test_y))
accuracy <- sum(diag(confusion_matrix)) / sum(confusion_matrix)
sensitivity <- confusion_matrix["1","1"] / sum(confusion_matrix[, "1"])  # TP / (TP + FN)
specificity <- confusion_matrix["0","0"] / sum(confusion_matrix[, "0"])  # TN / (TN + FP)
odds_ratio <- (sensitivity / (1 - sensitivity)) / (specificity / (1 - specificity))
LOR <- ifelse(odds_ratio == 0, -Inf, ifelse(odds_ratio == Inf, Inf, log(odds_ratio)))
auc <- pROC::auc(pROC::roc(response = test_y, predictor = as.numeric(predictions)))
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
# Display results
list(
  ConfusionMatrix = confusion_matrix,
  Accuracy = accuracy,
  Sensitivity = sensitivity,
  Specificity = specificity,
  OddsRatio = odds_ratio,
  LOR = LOR,
  AUC = auc
)
## $ConfusionMatrix
##          Actual
## Predicted  0  1
##         0 46  0
##         1  0 89
## 
## $Accuracy
## [1] 1
## 
## $Sensitivity
## [1] 1
## 
## $Specificity
## [1] 1
## 
## $OddsRatio
## [1] NaN
## 
## $LOR
## [1] NA
## 
## $AUC
## Area under the curve: 1
plot(history)

When comparing the binary classification to the multi-class classification under the same conditions (15 epochs and a batch size of 5), the perfect metrics (100% accuracy and an AUC of 1) reached by epoch 3 in the binary model should raise concerns about overfitting; alternatively, they may indicate that the test set was not challenging enough or that the data does not reflect real-world complexity.

Binary Classification Model Results

Confusion Matrix: The matrix shows that the model predicted all instances correctly; there are no false positives or false negatives.

Accuracy: The accuracy is 1 (or 100%), which means the model correctly classified all the instances in the test dataset.

Sensitivity: Also known as recall or the true positive rate, this is 1, indicating that the model identified all positive instances correctly.

Specificity: This is the true negative rate, and it is also 1, meaning all negative instances were correctly identified by the model.

Odds Ratio: This is undefined (NaN) because there are no false positives or false negatives: both sensitivity and specificity equal exactly 1, so the formula divides Inf by Inf, which R reports as NaN.

Log Odds Ratio (LOR): This is NA as a direct consequence of the odds ratio being NaN; without any false positives or false negatives, neither the odds ratio nor its logarithm can be computed.

Area Under the Curve (AUC): With a value of 1, this indicates perfect discrimination by the model between the positive and negative classes.
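One common workaround when the confusion matrix contains zero cells is the Haldane-Anscombe correction, which adds 0.5 to every cell so the odds ratio stays finite. A minimal sketch using the confusion matrix printed above:

```r
# Confusion matrix from the binary model (rows = predicted, cols = actual)
cm <- matrix(c(46, 0,
                0, 89),
             nrow = 2, byrow = TRUE,
             dimnames = list(Predicted = c("0", "1"), Actual = c("0", "1")))

# Haldane-Anscombe correction: add 0.5 to each cell before forming the ratio
cm_adj <- cm + 0.5
odds_ratio_adj <- (cm_adj["1", "1"] * cm_adj["0", "0"]) /
                  (cm_adj["1", "0"] * cm_adj["0", "1"])
log(odds_ratio_adj)  # finite log odds ratio instead of NA
```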

1.4 Visualization

library(ggplot2)
## Warning: package 'ggplot2' was built under R version 4.2.3
# 1. Scatter Plot of Group vs Age
ggplot(df, aes(x = Age, y = Group)) +
  geom_point() +
  labs(title = "Scatter Plot of Group vs Age", x = "Age", y = "Group")

# 2. Boxplot of MMSE across different Groups
ggplot(df, aes(x = Group, y = Age)) +
  geom_boxplot() +
  labs(title = "Boxplot of Age for Different Groups", x = "Group", y = "Age")

# 3. Histogram of Age
ggplot(df, aes(x = Age)) +
  geom_histogram(bins = 30) +
  labs(title = "Histogram of Age", x = "Age", y = "Count")

# 4. Density Plot of Age for Patients vs NC
ggplot(df, aes(x = Age, fill = Group)) +
  geom_density(alpha = 0.7) +
  labs(title = "Density Plot of Age for Different Groups", x = "Age", y = "Density")

# 5. Bar Plot for the count of different Groups
ggplot(df, aes(x = Group)) +
  geom_bar() +
  labs(title = "Count of Different Groups", x = "Group", y = "Count")

# Barplot of predictions
# Tabulate the predicted classes (tabulating the raw probabilities in
# 'predictions' would produce one bar per unique probability value)
prediction_table <- table(predicted_classes)

# Now we convert it into a data frame for plotting with ggplot2
prediction_df <- as.data.frame(prediction_table)

# Give proper names to the columns
names(prediction_df) <- c("Class", "Count")

# Use ggplot2 (loaded above) to create the bar graph

ggplot(prediction_df, aes(x = Class, y = Count, fill = Class)) +
  geom_bar(stat = "identity") + # "identity" to use the actual values in the 'Count' column
  theme_minimal() +
  labs(title = "Bar Graph of Predicted Classes", x = "Class", y = "Count") +
  scale_fill_brewer(palette = "Set2") + # Optional: to use different colors for each class 
  scale_x_discrete(labels = c("0" = "NC", "1" = "Patients")) 

# Visualizing model structure
summary(model)
## <pointer: 0x0>

In the pooled data, 448 records are labeled as Patients and 224 as non-cases (NC). Among the test-set predictions, 89 instances are classified as Patients versus 46 as NC, matching the confusion matrix above.

Although the rendered summary(model) output shows only a stale pointer, the model is a sequential neural network with two layers:

dense_1: A fully connected layer with 8 neurons that takes the 11 input features, with 96 parameters (11 × 8 + 8).

dense: The output layer with a single neuron, suitable for binary classification, with 9 parameters (8 × 1 + 1).

The network has a total of 105 trainable parameters, making it a relatively simple model. The ReLU and sigmoid activations allow it to capture non-linear patterns and to output probabilities for binary classification.
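These parameter counts can be checked directly: a dense layer with n_in inputs and u units has n_in × u weights plus u biases. A minimal sketch:

```r
# Parameters of a dense layer: one weight per input per unit, plus one bias per unit
dense_params <- function(n_in, units) n_in * units + units

dense_params(11, 8)                       # hidden layer: 96
dense_params(8, 1)                        # output layer: 9
dense_params(11, 8) + dense_params(8, 1)  # total: 105
```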

1.5 Multi-Class Modeling

set.seed(2024)
df2 <- web_data %>% html_table(fill = TRUE) %>% .[[1]]
df2$Group <- as.numeric(as.factor(df2$Group))-1
df2$Measure <- as.numeric(as.factor(df2$Measure))
df2$Sex <- as.numeric(as.factor(df2$Sex))
features <- df2[,-c(1,2)]
target <- df2[, 2]

indices <- sample(1:nrow(features), size = 0.8 * nrow(features), replace = FALSE)
train_x <- features[indices, ]
train_y <- target[indices,]
test_x <- features[-indices, ]
test_y <- target[-indices,]

# Define the model for multi-class classification
model <- keras_model_sequential() %>%
  layer_dense(units = 8, activation = 'relu', input_shape = c(ncol(train_x))) %>%
  layer_dense(units = 3, activation = 'softmax')  # Adjust for three classes

# Compile the model for multi-class classification
model %>% compile(
  loss = 'categorical_crossentropy',  # Change to categorical crossentropy
  optimizer = 'adam',
  metrics = 'accuracy'
)

# Convert 'train_y' and 'test_y' to one-hot encoded format
train_y_onehot <- to_categorical(as.matrix(train_y))
test_y_onehot <- to_categorical(as.matrix(test_y))

# Fit the model on one-hot encoded targets
history <- model %>% fit(
  as.matrix(train_x),
  train_y_onehot,
  epochs = 15,
  batch_size = 5,
  validation_split = 0.2
)
## Epoch 1/15
## 86/86 - 1s - loss: 621639.9375 - accuracy: 0.3333 - val_loss: 463797.7812 - val_accuracy: 0.3704 - 829ms/epoch - 10ms/step
## Epoch 2/15
## 86/86 - 0s - loss: 375609.9688 - accuracy: 0.3333 - val_loss: 260615.2344 - val_accuracy: 0.3704 - 219ms/epoch - 3ms/step
## Epoch 3/15
## 86/86 - 0s - loss: 162118.6562 - accuracy: 0.3333 - val_loss: 29400.5801 - val_accuracy: 0.3704 - 237ms/epoch - 3ms/step
## Epoch 4/15
## 86/86 - 0s - loss: 9215.8760 - accuracy: 0.2051 - val_loss: 5075.2012 - val_accuracy: 0.3704 - 194ms/epoch - 2ms/step
## Epoch 5/15
## 86/86 - 0s - loss: 3057.4426 - accuracy: 0.2308 - val_loss: 3777.0122 - val_accuracy: 0.6481 - 192ms/epoch - 2ms/step
## Epoch 6/15
## 86/86 - 0s - loss: 2795.2437 - accuracy: 0.2098 - val_loss: 2641.6807 - val_accuracy: 0.0000e+00 - 222ms/epoch - 3ms/step
## Epoch 7/15
## 86/86 - 0s - loss: 2609.0393 - accuracy: 0.1562 - val_loss: 2646.5647 - val_accuracy: 0.2963 - 204ms/epoch - 2ms/step
## Epoch 8/15
## 86/86 - 0s - loss: 2232.0449 - accuracy: 0.2051 - val_loss: 2059.3772 - val_accuracy: 0.3704 - 191ms/epoch - 2ms/step
## Epoch 9/15
## 86/86 - 0s - loss: 2008.0107 - accuracy: 0.1981 - val_loss: 2067.1299 - val_accuracy: 0.2870 - 195ms/epoch - 2ms/step
## Epoch 10/15
## 86/86 - 0s - loss: 1645.1154 - accuracy: 0.2168 - val_loss: 1322.6821 - val_accuracy: 0.0093 - 237ms/epoch - 3ms/step
## Epoch 11/15
## 86/86 - 0s - loss: 1513.8796 - accuracy: 0.2331 - val_loss: 1268.5308 - val_accuracy: 0.2778 - 191ms/epoch - 2ms/step
## Epoch 12/15
## 86/86 - 0s - loss: 1283.9553 - accuracy: 0.2471 - val_loss: 1954.7231 - val_accuracy: 0.3704 - 200ms/epoch - 2ms/step
## Epoch 13/15
## 86/86 - 0s - loss: 1103.4797 - accuracy: 0.3473 - val_loss: 979.8542 - val_accuracy: 0.3704 - 209ms/epoch - 2ms/step
## Epoch 14/15
## 86/86 - 0s - loss: 654.7548 - accuracy: 0.2984 - val_loss: 710.6323 - val_accuracy: 0.6019 - 238ms/epoch - 3ms/step
## Epoch 15/15
## 86/86 - 0s - loss: 454.2026 - accuracy: 0.4709 - val_loss: 174.0480 - val_accuracy: 0.3704 - 203ms/epoch - 2ms/step
# Predict on the test set with the multi-class model
predictions <- model %>% predict(as.matrix(test_x))
## 5/5 - 0s - 49ms/epoch - 10ms/step
# Convert predictions to class labels
predicted_classes <- apply(predictions, 1, which.max) - 1

# Calculate the confusion matrix and other evaluation metrics
confusion_matrix <- table(Predicted = predicted_classes, Actual = as.matrix(test_y))

# Accuracy
accuracy <- sum(diag(confusion_matrix)) / sum(confusion_matrix)

# For multi-class AUC, you need to calculate AUC for each class
auc_values <- sapply(1:ncol(test_y_onehot), function(class_index) {
  roc_obj <- pROC::roc(response = test_y_onehot[, class_index], predictor = predictions[, class_index])
  pROC::auc(roc_obj)
})
## Setting levels: control = 0, case = 1
## Setting direction: controls < cases
## Setting levels: control = 0, case = 1
## Setting direction: controls > cases
## Setting levels: control = 0, case = 1
## Setting direction: controls > cases
# Average AUC
mean_auc <- mean(auc_values)

# Display results
list(
  ConfusionMatrix = confusion_matrix,
  Accuracy = accuracy,
  AUC = mean_auc
)
## $ConfusionMatrix
##          Actual
## Predicted  0  1  2
##         0 41 48  0
##         1  0  0 46
## 
## $Accuracy
## [1] 0.3037037
## 
## $AUC
## [1] 0.6957535
plot(history)

summary(model)
## Model: "sequential_1"
## ________________________________________________________________________________
##  Layer (type)                       Output Shape                    Param #     
## ================================================================================
##  dense_3 (Dense)                    (None, 8)                       96          
##  dense_2 (Dense)                    (None, 3)                       27          
## ================================================================================
## Total params: 123
## Trainable params: 123
## Non-trainable params: 0
## ________________________________________________________________________________

Using the same training parameters ensures that the difference in performance is likely due to the increased complexity of the task rather than differences in the training process itself. As more classes are introduced, the model must learn to distinguish between a greater number of patterns, which typically makes the task more challenging and can lead to a decrease in performance metrics compared to a binary classification task.

Multi-Class Classification Model Results

Confusion Matrix: The printed table has rows for predicted classes 0 and 1 and columns for actual classes 0, 1, and 2; there is no row for predicted 'class 2' because the model never predicted that class.

All 41 instances whose actual class was 0 were correctly predicted as 'class 0'.

All 48 instances whose actual class was 1 were incorrectly predicted as 'class 0', so 'class 1' has no correct predictions.

All 46 instances whose actual class was 2 were incorrectly predicted as 'class 1', so 'class 2' likewise has no correct predictions.

Accuracy: The accuracy is about 30.37% (41 of 135 test instances correctly classified), far from the perfect score of the binary case.

Area Under the Curve (AUC): The mean per-class AUC is 0.6957535, which is higher than the accuracy but indicates only moderate separability across the three classes.

The first dense layer, dense_3, has 8 neurons and is connected to the input layer. Since the parameter count is 96, the input layer has 11 features (11 × 8 + 8 = 96).

The second dense layer, dense_2, has 3 neurons, corresponding to the 3 output classes; its 27 parameters reflect the 8 inputs from the previous layer (8 × 3 + 3 = 27).

With the same number of epochs, batch size, and seed used for both the binary and multi-class models, the multi-class model performed substantially worse, with an accuracy of about 30.37% and an AUC of approximately 0.696. This is an expected outcome, as the complexity of the task tends to increase with the addition of more classes.
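The imbalance in the predictions is easier to see as per-class recall, computed directly from the printed confusion matrix (padding the missing 'predicted 2' row with zeros):

```r
# Confusion matrix as printed (rows = predicted, cols = actual); the model
# never predicted class 2, so that row is all zeros
cm <- matrix(0, nrow = 3, ncol = 3,
             dimnames = list(Predicted = 0:2, Actual = 0:2))
cm["0", ] <- c(41, 48, 0)
cm["1", ] <- c(0, 0, 46)

recall <- diag(cm) / colSums(cm)  # fraction of each actual class recovered
recall  # class 0: 1; classes 1 and 2: 0
```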

2 Regression

# Data Extraction
library(rvest)
library(dplyr)
## Warning: package 'dplyr' was built under R version 4.2.3
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
url <- "https://wiki.socr.umich.edu/index.php/SOCR_Data_Dinov_032708_AllometricPlanRels"
web_data <- read_html(url)
df <- web_data %>% html_table(fill = TRUE) %>% .[[1]]
head(df)
## # A tibble: 6 × 8
##   `Province/Sites` `Alt.(m)` `Long.(E,deg.)` `Lat.(N,deg.)` Born    `L(g/no.)`
##   <chr>                <int>           <dbl>          <dbl> <chr>        <dbl>
## 1 Heilongjiang           800            129.           44.3 natural     17538.
## 2 Heilongjiang           550            125.           52.3 natural      9313.
## 3 Heilongjiang           441            127.           51.7 natural      2570.
## 4 Heilongjiang           590            132.           46.5 natural     13939.
## 5 Heilongjiang           800            130.           44.1 natural     14375 
## 6 Heilongjiang           590            125.           51.4 natural      9017.
## # ℹ 2 more variables: `M(g/no.)` <dbl>, `D(no./m2)` <dbl>
# Preprocess the data

df <- df %>%
  mutate(across(where(is.character), ~as.numeric(as.factor(.x))))

# Verify the structure
str(df)
## tibble [48 × 8] (S3: tbl_df/tbl/data.frame)
##  $ Province/Sites: num [1:48] 1 1 1 1 1 1 1 1 2 2 ...
##  $ Alt.(m)       : int [1:48] 800 550 441 590 800 590 876 500 880 900 ...
##  $ Long.(E,deg.) : num [1:48] 129 125 127 132 130 ...
##  $ Lat.(N,deg.)  : num [1:48] 44.3 52.3 51.7 46.5 44.1 ...
##  $ Born          : num [1:48] 1 1 1 1 1 1 1 1 1 1 ...
##  $ L(g/no.)      : num [1:48] 17538 9313 2570 13939 14375 ...
##  $ M(g/no.)      : num [1:48] 610990 298385 82175 422030 450643 ...
##  $ D(no./m2)     : num [1:48] 0.0394 0.0291 0.114 0.033 0.0544 ...
df <- df %>%
  rename(
    PS = `Province/Sites`,
    Altitude = `Alt.(m)`,
    Longitude = `Long.(E,deg.)`,
    Latitude = `Lat.(N,deg.)`,
    Length = `L(g/no.)`,
    Mass = `M(g/no.)`,
    Density = `D(no./m2)`
  )

2.1 Neural Net

# Split the data into training and testing sets
set.seed(123) # for reproducibility

train_indices <- sample(1:nrow(df), size = 0.7 * nrow(df))
train_data <- df[train_indices, ]
test_data <- df[-train_indices, ]

#install.packages("neuralnet")
library(neuralnet)
# Define the model
model <- neuralnet(Density ~ . , data = train_data, hidden = c(7,2), linear.output = FALSE)
plot(model,rep = "best")

The neural network is configured to predict Density from all remaining features in train_data, using two hidden layers with 7 and 2 neurons. With linear.output = FALSE, the output neuron applies its activation function, a setting more typical of classification; since Density is continuous, linear.output = TRUE would be the usual choice for regression, although the Density values here are well below 1. The error of 0.87 measures how closely the best model (out of possibly many repetitions) fits the training data; whether that is acceptable depends on the domain and task. Training converged in 217 steps, reflecting the difficulty of finding optimal weights for this architecture and data.
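The fitted network's generalization can be summarized with the test-set root mean squared error; a minimal sketch, assuming the `model` and `test_data` objects from above (predict() dispatches to neuralnet's method in recent package versions):

```r
# RMSE of the neuralnet model on the held-out test set
preds <- predict(model, newdata = test_data)
rmse  <- sqrt(mean((test_data$Density - preds)^2))
rmse
```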

2.2 RMSE

# Now separate the X and Y variables
features <- df %>%
  select(-Density) %>% # Exclude the response variable
  select(where(is.numeric)) %>%
  as.matrix()

response <- df$Density

train_indices <- sample(1:nrow(features), size = 0.7 * nrow(features))
train_features <- features[train_indices, ]
train_response <- response[train_indices]

test_features <- features[-train_indices, ]
test_response <- response[-train_indices]


# Define the model

model <- keras_model_sequential() %>%
  layer_dense(units = 256, activation = "relu") %>%
  layer_dense(units = 128, activation = "relu") %>%
  layer_dense(units = 64, activation = "relu") %>%
# layer_dense(units = 16, activation = "relu") %>%
  layer_dense(units = 1) 

# Compile the model
model %>% compile(
  loss = 'mse',
  optimizer = optimizer_rmsprop(),
  metrics = list("mean_absolute_error")
)

# Fit the model on the training data
history <- model %>% fit(
  as.matrix(train_features), as.matrix(train_response),
  epochs = 200,
  batch_size = 10,
  validation_split = 0.2
)
## Epoch 1/200
## 3/3 - 1s - loss: 2018269440.0000 - mean_absolute_error: 22177.6465 - val_loss: 19320790.0000 - val_mean_absolute_error: 3743.8501 - 731ms/epoch - 244ms/step
## Epoch 2/200
## 3/3 - 0s - loss: 81940584.0000 - mean_absolute_error: 5781.9619 - val_loss: 541667.3750 - val_mean_absolute_error: 612.5004 - 30ms/epoch - 10ms/step
## Epoch 3/200
## 3/3 - 0s - loss: 420525.3750 - mean_absolute_error: 378.3483 - val_loss: 16353.6895 - val_mean_absolute_error: 112.0617 - 31ms/epoch - 10ms/step
## Epoch 4/200
## 3/3 - 0s - loss: 53686.3828 - mean_absolute_error: 158.2263 - val_loss: 4885.2241 - val_mean_absolute_error: 60.9166 - 32ms/epoch - 11ms/step
## Epoch 5/200
## 3/3 - 0s - loss: 58311.3984 - mean_absolute_error: 96.0378 - val_loss: 1108321.8750 - val_mean_absolute_error: 909.3096 - 31ms/epoch - 10ms/step
## Epoch 6/200
## 3/3 - 0s - loss: 538943.4375 - mean_absolute_error: 429.0291 - val_loss: 279262.1562 - val_mean_absolute_error: 460.6824 - 32ms/epoch - 11ms/step
## Epoch 7/200
## 3/3 - 0s - loss: 136400.3281 - mean_absolute_error: 203.0796 - val_loss: 1333412.2500 - val_mean_absolute_error: 995.8418 - 30ms/epoch - 10ms/step
## Epoch 8/200
## 3/3 - 0s - loss: 27682180.0000 - mean_absolute_error: 3237.5864 - val_loss: 17711674.0000 - val_mean_absolute_error: 3588.7964 - 32ms/epoch - 11ms/step
## Epoch 9/200
## 3/3 - 0s - loss: 10073023.0000 - mean_absolute_error: 1971.4135 - val_loss: 378269.7188 - val_mean_absolute_error: 526.8839 - 32ms/epoch - 11ms/step
## Epoch 10/200
## 3/3 - 0s - loss: 6927231.5000 - mean_absolute_error: 1634.7599 - val_loss: 212869.8594 - val_mean_absolute_error: 390.4831 - 33ms/epoch - 11ms/step
## Epoch 11/200
## 3/3 - 0s - loss: 13129225.0000 - mean_absolute_error: 1915.7727 - val_loss: 117424000.0000 - val_mean_absolute_error: 9245.2441 - 34ms/epoch - 11ms/step
## Epoch 12/200
## 3/3 - 0s - loss: 558812672.0000 - mean_absolute_error: 13993.5732 - val_loss: 542833.8125 - val_mean_absolute_error: 650.6257 - 39ms/epoch - 13ms/step
## Epoch 13/200
## 3/3 - 0s - loss: 1684837.2500 - mean_absolute_error: 796.0423 - val_loss: 9776.2256 - val_mean_absolute_error: 81.7043 - 31ms/epoch - 10ms/step
## Epoch 14/200
## 3/3 - 0s - loss: 28271.7773 - mean_absolute_error: 111.4699 - val_loss: 7318.7710 - val_mean_absolute_error: 81.0536 - 32ms/epoch - 11ms/step
## Epoch 15/200
## 3/3 - 0s - loss: 6255.3887 - mean_absolute_error: 65.2097 - val_loss: 19603.3438 - val_mean_absolute_error: 133.5465 - 36ms/epoch - 12ms/step
## Epoch 16/200
## 3/3 - 0s - loss: 83476.8203 - mean_absolute_error: 187.1333 - val_loss: 2539.5422 - val_mean_absolute_error: 43.4944 - 36ms/epoch - 12ms/step
## Epoch 17/200
## 3/3 - 0s - loss: 2667.6514 - mean_absolute_error: 44.1748 - val_loss: 2941.7185 - val_mean_absolute_error: 46.2272 - 35ms/epoch - 12ms/step
## Epoch 18/200
## 3/3 - 0s - loss: 7185.3838 - mean_absolute_error: 60.9618 - val_loss: 1897.5573 - val_mean_absolute_error: 35.5266 - 34ms/epoch - 11ms/step
## Epoch 19/200
## 3/3 - 0s - loss: 27400.7930 - mean_absolute_error: 117.2747 - val_loss: 2352.3967 - val_mean_absolute_error: 39.3390 - 34ms/epoch - 11ms/step
## Epoch 20/200
## 3/3 - 0s - loss: 33816.3203 - mean_absolute_error: 103.3453 - val_loss: 538546.8750 - val_mean_absolute_error: 610.3983 - 32ms/epoch - 11ms/step
## Epoch 21/200
## 3/3 - 0s - loss: 33267608.0000 - mean_absolute_error: 3490.0210 - val_loss: 156138608.0000 - val_mean_absolute_error: 10675.3682 - 34ms/epoch - 11ms/step
## Epoch 22/200
## 3/3 - 0s - loss: 137009712.0000 - mean_absolute_error: 7531.5337 - val_loss: 271546624.0000 - val_mean_absolute_error: 14046.1436 - 33ms/epoch - 11ms/step
## Epoch 23/200
## 3/3 - 0s - loss: 151299984.0000 - mean_absolute_error: 8139.8408 - val_loss: 43819368.0000 - val_mean_absolute_error: 5633.1338 - 34ms/epoch - 11ms/step
## Epoch 24/200
## 3/3 - 0s - loss: 80227248.0000 - mean_absolute_error: 6075.7715 - val_loss: 421272.5000 - val_mean_absolute_error: 536.9413 - 34ms/epoch - 11ms/step
## ... (epochs 25-198 omitted: the training and validation losses oscillate wildly, repeatedly spiking above 1e8 before settling late in training) ...
## Epoch 199/200
## 3/3 - 0s - loss: 246666.6406 - mean_absolute_error: 295.3921 - val_loss: 46679.2266 - val_mean_absolute_error: 183.0828 - 32ms/epoch - 11ms/step
## Epoch 200/200
## 3/3 - 0s - loss: 176515.6719 - mean_absolute_error: 269.3751 - val_loss: 1371.3434 - val_mean_absolute_error: 30.9480 - 33ms/epoch - 11ms/step
# Evaluate the model on the test data
score <- model %>% evaluate(test_features, test_response, verbose = 0)
score 
##                loss mean_absolute_error 
##          3489.22021            39.01598
# Predict on test data
predictions <- model %>% predict(test_features)
## 1/1 - 0s - 55ms/epoch - 55ms/step
# Calculate RMSE
rmse <- round(sqrt(mean((predictions - test_response)^2)),3)
print(paste("RMSE on test data:", rmse))
## [1] "RMSE on test data: 59.07"
# Calculate correlations
correlation <- round(cor(predictions, test_response),3)
print(paste("Correlation between actual and predicted values:", correlation))
## [1] "Correlation between actual and predicted values: 0.452"
# Plot the training history (loss and MAE per epoch)
plot(history)

print(model)
## Model: "sequential_2"
## ________________________________________________________________________________
##  Layer (type)                       Output Shape                    Param #     
## ================================================================================
##  dense_7 (Dense)                    (None, 256)                     2048        
##  dense_6 (Dense)                    (None, 128)                     32896       
##  dense_5 (Dense)                    (None, 64)                      8256        
##  dense_4 (Dense)                    (None, 1)                       65          
## ================================================================================
## Total params: 43,265
## Trainable params: 43,265
## Non-trainable params: 0
## ________________________________________________________________________________
#install.packages("plotly")
#install.packages("fastmap")
#library(fastmap)
library(plotly)

#epochs <- 50
#time <- 1:epochs
#hist_df <- data.frame(time=time, loss=history$metrics$loss, mae=history$metrics$mean_absolute_error,
#                      valid_loss=history$metrics$val_loss, valid_mae=history$metrics$val_mean_absolute_error)

#plot_ly(hist_df, x = ~time) %>%
#  add_trace(y = ~loss, name = 'training loss', type = "scatter", mode = 'lines') %>%
#  add_trace(y = ~mae, name = 'training MAE', type = "scatter", mode = 'lines+markers') %>%
#  add_trace(y = ~valid_loss, name = 'validation loss', type = "scatter", mode = 'lines+markers') %>%
#  add_trace(y = ~valid_mae, name = 'validation MAE', type = "scatter", mode = 'lines+markers') %>%
#  layout(title = "NN Model Performance",
#         legend = list(orientation = 'h'),
#         yaxis = list(title = "Metric"))

# Assemble the actual and predicted test values for plotting
hist_df <- data.frame(real=as.matrix(test_response), predicted=predictions)

plot_ly(hist_df, x = ~real)  %>%
  add_trace(y = ~predicted, name = 'Scatter (Actual vs. Predicted)', type="scatter", mode = 'markers') %>%
  add_lines(x = ~real, y = ~fitted(lm(predicted ~ real, hist_df)), name="LM(Pred ~ Real)") %>% 
  layout(title=paste0("NN Model Prediction (correlation=", correlation,")"),
           legend = list(orientation = 'h'), yaxis=list(title="predicted"))
print(paste0("Corr(Actual, Predicted)=", 
             round(cor(predictions, test_response), 3)))
## [1] "Corr(Actual, Predicted)=0.452"
summary(model)
## Model: "sequential_2"
## ________________________________________________________________________________
##  Layer (type)                       Output Shape                    Param #     
## ================================================================================
##  dense_7 (Dense)                    (None, 256)                     2048        
##  dense_6 (Dense)                    (None, 128)                     32896       
##  dense_5 (Dense)                    (None, 64)                      8256        
##  dense_4 (Dense)                    (None, 1)                       65          
## ================================================================================
## Total params: 43,265
## Trainable params: 43,265
## Non-trainable params: 0
## ________________________________________________________________________________

The training process was run for 200 epochs, where an epoch is one complete pass through the entire training dataset. Note that the training trace above is quite unstable: both the training and validation losses repeatedly spike by several orders of magnitude (above 1e8 at the worst points) before recovering, behavior that often indicates a learning rate that is too large or input features that have not been scaled.

The reported root mean squared error (RMSE) on the test data is 59.07. This is small relative to the scale of the response, whose values span roughly 0 to 10,000, so in absolute terms the model could be deemed reasonably accurate.
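To put the 59.07 RMSE on the 0-10,000 scale, one can compute a range-normalized RMSE; a base-R sketch using the values quoted above:

```r
rmse_value <- 59.07            # RMSE reported on the test data
response_range <- 10000 - 0    # approximate span of the response values
nrmse <- rmse_value / response_range
round(100 * nrmse, 2)          # RMSE as a percentage of the response range
```

This works out to well under 1% of the response range.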

The correlation of 0.452 suggests a moderate positive linear relationship between the actual and predicted values: the model has learned some of the underlying patterns in the data, but a substantial amount of variance remains uncaptured. A higher correlation would indicate a better fit of the model to the data.
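Correlation and RMSE capture different aspects of fit: Pearson correlation is invariant to constant shifts and rescaling of the predictions, whereas RMSE is not, so the two should be read together. A toy base-R illustration (synthetic vectors, not the actual test data):

```r
rmse_fn <- function(pred, obs) sqrt(mean((pred - obs)^2))

actual <- c(10, 20, 30, 40, 50)
biased <- actual + 100         # perfectly correlated with actual, but offset

cor(biased, actual)            # 1: correlation ignores the constant bias
rmse_fn(biased, actual)        # 100: RMSE penalizes the bias in full
```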

The model “sequential_2” is a densely connected (fully connected) neural network with four layers. The first hidden layer has 256 neurons, the second 128, the third 64, and the output layer has a single neuron, consistent with the regression task here (a single continuous output). The model is fairly complex, with a total of 43,265 parameters, all of which are trainable. This gives it a high capacity for learning from data, but also a risk of overfitting if the training data are limited or no regularization is applied. Each layer’s parameter count is the number of connections to all neurons in the preceding layer plus one bias term per neuron. The summary confirms that no parameters are frozen or non-trainable, so the entire model is updated during training.
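These parameter counts can be reproduced by hand: a dense layer with `n_in` inputs and `n_out` units has `(n_in + 1) * n_out` parameters (one weight per input per unit, plus a bias per unit). The first layer's 2,048 parameters imply seven input features, since (7 + 1) * 256 = 2048. A base-R check:

```r
# Parameters of a dense layer: one weight per input per unit, plus one bias per unit
dense_params <- function(n_in, n_out) (n_in + 1) * n_out

layer_sizes <- c(7, 256, 128, 64, 1)   # 7 input features inferred from 2048 = (7+1)*256
params <- mapply(dense_params,
                 n_in  = head(layer_sizes, -1),
                 n_out = tail(layer_sizes, -1))
params        # 2048 32896 8256 65, matching the Param # column
sum(params)   # 43265, matching the reported total
```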

3 Image classification

#install.packages("magick")
#install.packages("dplyr")
#py_install("Pillow", envname = "r-reticulate")

library(keras)
library(dplyr)
library(magick)

# Function to classify an image using multiple models
classify_image <- function(img_url) {
  # Download the image (create the results/ directory first if it is missing)
  dir.create(file.path(getwd(), "results"), showWarnings = FALSE)
  download.file(img_url, paste(getwd(), "results/image.png", sep = "/"), mode = 'wb')
  
  # Read the image and resize
  img <- image_read(paste(getwd(),"results/image.png", sep="/")) %>% image_resize("224x224!")
  img_for_display <- image_read(paste(getwd(),"results/image.png", sep="/")) %>% image_resize("800x800")

  # Preprocess the image for prediction
  x <- as.integer(image_data(img))

  # ensure we have a 4D tensor with a single element in the batch dimension,
  # then preprocess the input for the ImageNet-trained models
  x <- array_reshape(x, c(1, dim(x)))
  x <- imagenet_preprocess_input(x)
  
  # Initialize list to store predictions
  predictions_list <- list()
  
  # Model 1: ResNet50
  model_resnet50 <- application_resnet50(weights = 'imagenet')
  preds_resnet50 <- predict(model_resnet50, x)
  predictions_list$resnet50 <- imagenet_decode_predictions(preds_resnet50, top = 5)[[1]]
  
  # Model 2: VGG19
  model_vgg19 <- application_vgg19(weights = 'imagenet')
  preds_vgg19 <- predict(model_vgg19, x)
  predictions_list$vgg19 <- imagenet_decode_predictions(preds_vgg19, top = 5)[[1]]
  
  # Model 3: VGG16
  model_vgg16 <- application_vgg16(weights = 'imagenet')
  preds_vgg16 <- predict(model_vgg16, x)
  predictions_list$vgg16 <- imagenet_decode_predictions(preds_vgg16, top = 5)[[1]]
   
  
  # Return a list containing both the predictions and the image for display
  return(list(predictions = predictions_list, image = img_for_display))

}

# Use the function to classify an image
Daisy <- classify_image("https://fileinfo.com/img/ss/xl/jpeg_43-2.jpg")
## 1/1 - 1s - 1s/epoch - 1s/step
## 1/1 - 0s - 475ms/epoch - 475ms/step
## 1/1 - 0s - 391ms/epoch - 391ms/step
Volcano <- classify_image("https://media-cldnry.s-nbcnews.com/image/upload/t_fit-1500w,f_auto,q_auto:best/rockcms/2023-05/230522-Mexico-volcano-Popocatepetl-eruption-lava-ac-902p-373199.jpg")
## 1/1 - 1s - 1s/epoch - 1s/step
## 1/1 - 0s - 437ms/epoch - 437ms/step
## 1/1 - 0s - 413ms/epoch - 413ms/step
Brain <- classify_image("https://media.wired.com/photos/59324e5452d99d6b984dd9a0/master/pass/brain1.jpg")
## 1/1 - 1s - 1s/epoch - 1s/step
## 1/1 - 0s - 477ms/epoch - 477ms/step
## 1/1 - 0s - 436ms/epoch - 436ms/step
# Print out the image and predictions for each model
print(Daisy$image)
## # A tibble: 1 × 7
##   format width height colorspace matte filesize density
##   <chr>  <int>  <int> <chr>      <lgl>    <int> <chr>  
## 1 JPEG     800    554 sRGB       FALSE        0 72x72

print(Daisy$predictions)
## $resnet50
##   class_name class_description      score
## 1  n11939491             daisy 0.75317287
## 2  n03457902        greenhouse 0.08539081
## 3  n03930313      picket_fence 0.03835879
## 4  n03782006           monitor 0.01346979
## 5  n04485082            tripod 0.01004972
## 
## $vgg19
##   class_name class_description       score
## 1  n11939491             daisy 0.733606100
## 2  n03991062               pot 0.038622856
## 3  n03891251        park_bench 0.023287510
## 4  n03930313      picket_fence 0.020245489
## 5  n02280649 cabbage_butterfly 0.009900589
## 
## $vgg16
##   class_name class_description      score
## 1  n11939491             daisy 0.52277225
## 2  n11879895          rapeseed 0.19854690
## 3  n03930313      picket_fence 0.04138704
## 4  n02280649 cabbage_butterfly 0.03947129
## 5  n02281406 sulphur_butterfly 0.02848973

ResNet50: Daisy (75.32%), greenhouse (8.54%), picket fence (3.84%), monitor (1.35%), tripod (1.00%).

VGG19: Daisy (73.36%), pot (3.86%), park bench (2.33%), picket fence (2.02%), cabbage butterfly (0.99%).

VGG16: Daisy (52.28%), rapeseed (19.85%), picket fence (4.14%), cabbage butterfly (3.95%), sulphur butterfly (2.85%).

print(Volcano$image)
## # A tibble: 1 × 7
##   format width height colorspace matte filesize density
##   <chr>  <int>  <int> <chr>      <lgl>    <int> <chr>  
## 1 JPEG     800    533 sRGB       FALSE        0 72x72

print(Volcano$predictions)
## $resnet50
##   class_name class_description        score
## 1  n09472597           volcano 1.000000e+00
## 2  n04330267             stove 3.136750e-10
## 3  n09288635            geyser 1.817625e-10
## 4  n03347037       fire_screen 1.810664e-10
## 5  n02939185           caldron 5.460157e-11
## 
## $vgg19
##   class_name class_description        score
## 1  n09472597           volcano 9.999828e-01
## 2  n04456115             torch 1.468843e-05
## 3  n03729826        matchstick 1.558330e-06
## 4  n01443537          goldfish 3.081052e-07
## 5  n04330267             stove 2.524659e-07
## 
## $vgg16
##   class_name class_description        score
## 1  n09472597           volcano 9.997590e-01
## 2  n04456115             torch 1.812786e-04
## 3  n01443537          goldfish 1.262610e-05
## 4  n04330267             stove 1.225619e-05
## 5  n01910747         jellyfish 9.439173e-06

ResNet50: Volcano (100%), stove (0.00000003137%), geyser (0.00000001818%), fire screen (0.00000001811%), caldron (0.00000000546%).

VGG19: Volcano (99.98%), torch (0.0015%), matchstick (0.00016%), goldfish (0.00003%), stove (0.000025%).

VGG16: Volcano (99.98%), torch (0.018%), goldfish (0.0013%), stove (0.0012%), jellyfish (0.00094%).

print(Brain$image)
## # A tibble: 1 × 7
##   format width height colorspace matte filesize density
##   <chr>  <int>  <int> <chr>      <lgl>    <int> <chr>  
## 1 JPEG     800    640 sRGB       FALSE        0 72x72

print(Brain$predictions)
## $resnet50
##   class_name class_description      score
## 1  n01917289       brain_coral 0.06373768
## 2  n03041632           cleaver 0.06237394
## 3  n03627232              knot 0.05447683
## 4  n01930112          nematode 0.05420838
## 5  n03720891            maraca 0.04702024
## 
## $vgg19
##   class_name class_description      score
## 1  n13037406         gyromitra 0.66916829
## 2  n07720875       bell_pepper 0.04610845
## 3  n04599235              wool 0.03687157
## 4  n01917289       brain_coral 0.02477533
## 5  n07695742           pretzel 0.02373043
## 
## $vgg16
##   class_name class_description      score
## 1  n01917289       brain_coral 0.21701127
## 2  n13037406         gyromitra 0.19762117
## 3  n04599235              wool 0.05079585
## 4  n12267677             acorn 0.02914893
## 5  n03840681           ocarina 0.02528968

ResNet50: Brain coral (6.37%), cleaver (6.24%), knot (5.45%), nematode (5.42%), maraca (4.70%).

VGG19: Gyromitra (66.92%), bell pepper (4.61%), wool (3.69%), brain coral (2.48%), pretzel (2.37%).

VGG16: Brain coral (21.70%), gyromitra (19.76%), wool (5.08%), acorn (2.91%), ocarina (2.53%).

The predictions reflect the confidence level of the models in recognizing the objects in the images, with the score representing the probability assigned to each label by the respective models.

3.1 Interpretation

For the Daisy image, all three models correctly identify the daisy with varying degrees of confidence, indicating that the image is likely a clear representation of a daisy and that the models have been well-trained to recognize this class.

For the Volcano image, the models are extremely confident that the image is of a volcano. This high confidence across all models suggests that the image has very distinctive features that are strongly associated with the concept of a volcano.

For the Brain image, the models do not predict a human brain but rather objects with a similar appearance, such as brain coral and a type of fungus (gyromitra). This reflects an inherent limitation: the 1,000 ImageNet classes do not include a human-brain category, so the models can only fall back on visually similar textures and shapes learned from the ImageNet database.

Based on the available data, ResNet50 appears to be the most reliable, showing the highest confidence on the daisy and volcano images and offering a thematically related top prediction (brain coral) for the brain image. VGG19's top prediction for the brain (gyromitra) is the least related of the three, suggesting it may be the least accurate model on this kind of image.
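One way to combine the three models' opinions is a simple score-averaging ensemble over the top-5 tables returned by classify_image(). The helper below is a hypothetical sketch, not part of the original analysis; labels missing from a model's top 5 implicitly contribute a score of zero to the average:

```r
# Hypothetical ensemble helper: average each label's score across the
# per-model top-5 prediction tables (assuming the list structure shown above)
ensemble_top <- function(preds) {
  combined <- do.call(rbind, preds)                        # stack the per-model tables
  agg <- aggregate(score ~ class_description, combined, sum)
  agg$score <- agg$score / length(preds)                   # average over the models
  agg[order(-agg$score), ]                                 # highest average score first
}
# Example: ensemble_top(Brain$predictions)
# Using the scores shown above, gyromitra and brain_coral would come out on top.
```

Averaging dampens any single model's idiosyncratic guesses (e.g., VGG19's goldfish for the volcano), which is why simple ensembles often beat their individual members.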

These results show the strengths and limitations of pre-trained ImageNet models when applied to specific images. While they are generally good at recognizing a wide range of objects, their accuracy can vary depending on the similarity of the test images to the training data and the distinctiveness of the image features.